import argparse
import os
import subprocess

import tensorflow as tf
from utils.helpers import dump_frozen_graph, load_module
from ssd_detector.networks.mobilenet_ssd import MobileNetSSD


def parse_args():
  """
  Read the command line of the export script.

  :return: namespace with two positional values: `path_to_config` (config module path) and `mo` (mo.py path)
  """
  arg_parser = argparse.ArgumentParser(description='Export model in IE format')
  positional_args = (
    ('path_to_config', 'Path to a config.py'),
    ('mo', "Path to model optimizer 'mo.py' script"),
  )
  for dest, help_text in positional_args:
    arg_parser.add_argument(dest, help=help_text)
  return arg_parser.parse_args()


def execute_tfmo(mo_py_path, frozen, config, input_shape=None, **kwargs):
  """
  Dump the custom-operations JSON config next to a frozen graph and run Model Optimizer for TensorFlow on it.

  Model Optimizer is asked to write the converted Inference Engine IR into the folder that contains `frozen`.

  :param mo_py_path: path to Model Optimizer mo.py; if empty or None, Model Optimizer is not run
  :param frozen: path to the frozen graph to convert; must end with '.pb.frozen'
  :param config: dictionary generated by SSD_Base.get_config_for_tfmo(); its 'json' and 'cut_points' entries are used
  :param input_shape: input shape to pass to Model Optimizer. This is mostly needed to overwrite the batch size dimension
  :param kwargs: optional 'scale' and 'mean_values', forwarded to Model Optimizer when truthy
  :return: exit code of the Model Optimizer process, or None when `mo_py_path` is not given
  :raises ValueError: if `frozen` does not end with '.pb.frozen'
  """
  suffix = '.pb.frozen'
  # Explicit check instead of `assert`, which is stripped when Python runs with -O.
  if not frozen.endswith(suffix):
    raise ValueError("Frozen graph path must end with '{}': {}".format(suffix, frozen))

  folder = os.path.dirname(frozen)
  stem = frozen[:-len(suffix)]
  # --model_name names the network and its .xml/.bin files, so pass a bare name. Slicing off the
  # suffix (instead of str.replace on the whole path) also leaves directories containing '.pb' intact.
  model_name = os.path.basename(stem)

  json_path = stem + '.tfmo.json'
  with open(json_path, mode='w') as json_file:
    json_file.write(config['json'])

  if not mo_py_path:
    print('\nPath to `mo.py` is not specified. Please provide correct path to Model Optimizer `mo.py` script')
    return None

  scale = kwargs.get('scale')
  mean_values = kwargs.get('mean_values')

  params = (
    'python3', '-u', mo_py_path,
    '--framework=tf',
    '--input_model={}'.format(frozen),
    '--output_dir={}'.format(folder),
    '--output={}'.format(','.join(config['cut_points'])),
    '--input_shape=[{}]'.format(','.join(map(str, input_shape))) if input_shape else '',
    '--scale={}'.format(scale) if scale else '',
    '--mean_values=[{}]'.format(','.join(map(str, mean_values))) if mean_values else '',
    '--model_name={}'.format(model_name),
    '--tensorflow_use_custom_operations_config={}'.format(json_path),
  )

  # Optional flags are rendered as '' when absent; drop them so they are not passed as empty arguments.
  return_code = subprocess.call([p for p in params if p])
  if return_code != 0:
    # Surface the failure; otherwise a failed conversion would go unnoticed.
    print('\nModel Optimizer failed with exit code {}'.format(return_code))
  return return_code


def convert_to_ie(ssd, session, output_folder, mo_py_path, batch_size=None, **kvargs):
  """
  Single high-level function that converts the current graph to IE model format.

  Freezes the graph held by `session` into `output_folder` and then runs Model Optimizer on the frozen file.

  :param ssd: Instance derived from SSDBase; its get_config_for_tfmo() and input_shape are used
  :param session: session with graph and initialized variables
  :param output_folder: folder where intermediate and final results of the conversion are dumped;
                        a relative path is resolved against the current working directory
  :param mo_py_path: path to Model Optimizer mo.py for TensorFlow
  :param batch_size: batch size to set for the target model; if None, the input shape is not overridden
  :param kvargs: optional 'mean_values' and 'scale', forwarded to Model Optimizer
  :return: the value returned by execute_tfmo
  """
  if not os.path.isabs(output_folder):
    output_folder = os.path.join(os.getcwd(), output_folder)

  config = ssd.get_config_for_tfmo()
  graph_file = os.path.join(output_folder, 'graph.pb')
  # NOTE(review): execute_tfmo requires this to end with '.pb.frozen' - presumably guaranteed by dump_frozen_graph.
  frozen = dump_frozen_graph(session, graph_file, config['output_nodes'])

  # Only override the input shape when a concrete batch size is requested. Otherwise the shape would start
  # with None and be rendered as '--input_shape=[None,...]', which is not a valid Model Optimizer shape.
  input_shape = None
  if batch_size is not None:
    input_shape = [batch_size] + list(ssd.input_shape[1:])
  return execute_tfmo(mo_py_path, frozen, config, input_shape, **kvargs)


def export(cfg, tfmo):
  """
  Restore the latest checkpoint of the configured MobileNetSSD and convert it to Inference Engine IR.

  Intermediate and final files are written to `<cfg.model_dir>/ie_model/`.

  :param cfg: config module; its model_dir, detector_params and input_shape attributes are read
  :param tfmo: path to Model Optimizer mo.py script
  :raises ValueError: if no checkpoint is found in cfg.model_dir
  """
  checkpoint_path = tf.train.latest_checkpoint(cfg.model_dir)
  # latest_checkpoint returns None when nothing is found; fail before building the graph
  # with a message that names the directory instead of a generic error from Saver.restore.
  if checkpoint_path is None:
    raise ValueError('No checkpoint found in {}'.format(cfg.model_dir))

  # These config entries are not passed to the inference-mode (is_training=False) network.
  detector_params = cfg.detector_params.copy()
  for unnecessary_param in ('initial_weights_path',
                            'learning_rate',
                            'optimizer',
                            'weights_decay_factor',
                            'collect_priors_summary'):
    detector_params.pop(unnecessary_param, None)

  with tf.Session() as sess:
    input_tensor = tf.placeholder(dtype=tf.float32, shape=(None,) + tuple(cfg.input_shape))

    ssd = MobileNetSSD(input_tensor=input_tensor, is_training=False, **detector_params)
    ssd.detection_output()

    # NOTE(review): the placeholder is (None,) + input_shape, so input_shape[0] looks like the height
    # dimension if the layout is NHWC, yet it is passed as width here. This is harmless for square inputs;
    # confirm against the training code.
    train_param, _ = ssd.create_transform_parameters(width=cfg.input_shape[0], height=cfg.input_shape[1])

    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # The same mean is used for all three channels; Model Optimizer gets the reciprocal of train_param.scale.
    mean_values = [train_param.mean_value] * 3
    convert_to_ie(ssd, sess, os.path.join(cfg.model_dir, 'ie_model/'), tfmo, batch_size=1,
                  scale=1. / train_param.scale, mean_values=mean_values)


def main(_):
  """Entry point for tf.app.run: parse the command line, load the config module and export the model."""
  arguments = parse_args()
  export(load_module(arguments.path_to_config), arguments.mo)


if __name__ == '__main__':
  # Log at INFO level, then hand control to tf.app.run, which invokes main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)
